'''
    Main process for retraining subnetworks from scratch (bulk sweep over DAGs, learning rates and seeds).
'''
from distutils.command.build import build
import os
import pdb
import time
import math
import pickle
import random
import argparse
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
from tqdm import tqdm
from thop import profile
import json

import torch
import torch.optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
# import torchvision.models as models
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms

# https://github.com/microsoft/mup
import mup
from mup import make_base_shapes, set_base_shapes, MuSGD, MuAdam

# from models import model_dict
import models
from models.tiny_network import code2arch_str
from dag_utils import build_model, dag2affinity, effective_depth_width
from dataset import cifar10_dataloaders, cifar100_dataloaders, svhn_dataloaders, mnist_dataloaders, imagenet_dataloaders, imagenet_16_120_dataloaders
from logger import prepare_seed, prepare_logger
from utils import save_checkpoint, warmup_lr, AverageMeter, accuracy

from pruner import check_sparsity, pruning_model_random, pruning_model_random_layer_specified, prune_model_custom, extract_mask

from pdb import set_trace as bp

__all__ = [
    "auto_scale_lr_depth",
    "auto_scale_lr_depth_width",
    "code2depth_201",
]

# Every lower-case, non-dunder, callable attribute of the ``models`` package is
# offered as a choice for ``--arch``.
model_names = sorted(
    attr_name
    for attr_name, attr_value in vars(models).items()
    if attr_name.islower() and not attr_name.startswith("__") and callable(attr_value)
)

# Command-line interface; parsed by main() into the module-global ``args``.
parser = argparse.ArgumentParser(description='PyTorch Training Subnetworks')

##################################### Dataset #################################################
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset: cifar10 | cifar100 | imagenet16_120 | svhn | mnist | tinyimagenet | imagenet')
parser.add_argument('--img-size', type=int, default=None, help='width/height of input image')

##################################### Architecture ############################################
parser.add_argument('--arch', type=str, default='mlp', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: mlp)')
parser.add_argument('--dag', type=str, default=None, help='from-to edges separated by underscore. 0: broken edge; 1: skip-connect; 2: linear or conv')
parser.add_argument('--dags', type=str, nargs='+', default=None, help='from-to edges separated by underscore. 0: broken edge; 1: skip-connect; 2: linear or conv')
parser.add_argument('--imagenet_arch', action="store_true", help="back to imagenet architecture (conv1, maxpool)")
parser.add_argument('--width', type=int, default=None, help='hidden width (for mlp, cnn)')
# NOTE(review): `type=list` turns the raw argument string into a list of single
# characters (e.g. "16" -> ['1', '6']), and the job name later formats the width
# with "%d"; kept unchanged for CLI compatibility -- confirm the intended format.
parser.add_argument('--widths', type=list, default=None, help='hidden widths (for sss in NATS)')
parser.add_argument('--bn', action="store_true", help="use BN")
parser.add_argument('--fc', action="store_true", help="use FC in ConvNet")
parser.add_argument('--kernel_size', type=int, default=3, help='kernel size')
parser.add_argument('--stride', type=int, default=1, help='stride')
parser.add_argument('--rand_prune', default=-1, type=float, help="random global unstructured: fraction of parameters to prune")
parser.add_argument('--mup', action="store_true", help="use MuP setting")
parser.add_argument('--adam', action="store_true", help="use Adam")
parser.add_argument('--act', type=str, default='relu', choices=['relu', 'gelu'], help='activation function')

##################################### General setting ############################################
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--workers', type=int, default=4, help='number of workers in dataloader')
parser.add_argument('--resume', action="store_true", help="resume from checkpoint")
parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint file')
parser.add_argument('--inference', action="store_true", help="testing")
parser.add_argument('--save_dir', help='The directory used to save the trained models', default='./experiment', type=str)
parser.add_argument('--exp_name', help='additional names for experiment', default='', type=str)
parser.add_argument('--repeat', default=1, type=int, help='repeat training of DAG w. different random seed')
parser.add_argument('--reverse_order', action="store_true", help="bulk train in reverse order")
parser.add_argument('--start_idx', type=int, default=-1, help='index of first dag to train (inclusive)')
parser.add_argument('--end_idx', type=int, default=-1, help='index of last dag to train (EXclusive)')

##################################### Training setting #################################################
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--lr', default=None, type=float, help='initial learning rate (if unset, sweep a preset list of LRs)')
parser.add_argument('--lr_autoscale', action="store_true", help="automatically re-scale LR")
parser.add_argument('--momentum', default=0., type=float, help='momentum')  # commonly 0.9
parser.add_argument('--weight_decay', default=0, type=float, help='weight decay')  # commonly 1e-4
parser.add_argument('--epochs', default=None, type=int, help='number of total epochs to run')
parser.add_argument('--steps', default=None, type=int, help='number of total steps to run (within 1 epoch)')
parser.add_argument('--nesterov', action="store_true", help="use nesterov")
parser.add_argument('--aug', action="store_true", help="use augmentation")
parser.add_argument('--cutout', type=int, default=-1, help='cutout length')
parser.add_argument('--warmup', default=0, type=int, help='warm up epochs')
parser.add_argument('--decreasing_lr', action="store_true", help='use a cosine-annealing LR schedule')
parser.add_argument('--save_ckeckpoint_freq', default=-1, type=int, help='save intermediate checkpoint per epoch')
parser.add_argument('--start_idx_lr', type=int, default=-1, help='index of first LR in the target LR list to train (inclusive)')
parser.add_argument('--end_idx_lr', type=int, default=-1, help='index of last LR in the target LR list to train (EXclusive)')
parser.add_argument('--pretrained', action="store_true", help="use official pretrained checkpoint")


# Staleness threshold for ".ing" lock files, in seconds (4 hours). If a lock
# file has not been touched for longer than this, the job that owned it is
# assumed dead and the run may be resumed; otherwise another job is presumed to
# still be working on it.
ING_GAP_SEC = 4 * 60 * 60
def check_modified_time(filename):
    """Return the last-modification time of ``filename`` in seconds since the epoch.

    ``main`` compares this value with ``time.time()`` to decide whether the job
    holding an ``.ing`` lock file is still alive. The value is a float with
    sub-second precision where the filesystem provides it.

    Raises:
        OSError: (e.g. ``FileNotFoundError``) if ``filename`` cannot be stat'ed.
    """
    return os.path.getmtime(filename)


def auto_scale_lr_depth(lr_base, depth_base, depth):
    """Rescale a learning rate tuned at ``depth_base`` for a network of ``depth``.

    The learning rate is multiplied by ``(depth / depth_base) ** -1.5``, so a
    deeper network receives a smaller learning rate.
    """
    relative_depth = 1. * depth / depth_base
    depth_factor = relative_depth ** (-1.5)
    return lr_base * depth_factor


def auto_scale_lr_depth_width(lr_base, depths_base, depths):
    """Rescale ``lr_base`` from a reference architecture to a target architecture.

    Each architecture is summarised by a list of path depths ``d_i``. The LR is
    multiplied by ``(sum_i d_i**3) ** -0.5`` of the target (``depths``) and
    divided by the same quantity of the reference (``depths_base``), i.e.

        lr = lr_base * sqrt(sum(depths_base**3) / sum(depths**3))

    An empty or all-zero ``depths`` yields an infinite scale factor.
    """
    def _inv_sqrt_cube_sum(values):
        return (np.array(values) ** 3).sum() ** -0.5

    target_factor = _inv_sqrt_cube_sum(depths)
    reference_factor = _inv_sqrt_cube_sum(depths_base)
    return lr_base * target_factor / reference_factor


def code2depth_201(code):
    """Return the largest number of ``nor_conv`` ops on any connected path of a cell.

    ``code`` is converted by ``code2arch_str`` into a string of three ``+``-separated
    node groups, e.g. ``|op~0|+|op~0|op~1|+|op~0|op~1|op~2|`` (NAS-Bench-201 style;
    the ``~k`` suffix is taken to name the op's source node). The four
    input-to-output paths are enumerated explicitly; a path containing a
    ``none`` op is disconnected and ignored. Returns 0 when every path is
    disconnected.
    """
    groups = [token[1:-1].split("|") for token in code2arch_str(code).split('+')]
    paths = (
        (groups[0][0], groups[1][1], groups[2][2]),  # 0 -> 1 -> 2 -> 3
        (groups[0][0], groups[2][1]),                # 0 -> 1 -> 3
        (groups[1][0], groups[2][2]),                # 0 -> 2 -> 3
        (groups[2][0],),                             # 0 -> 3
    )
    conv_counts = []
    for ops in paths:
        joined = '+'.join(ops)
        if "none" not in joined:
            conv_counts.append(joined.count("nor_conv"))
    return max(conv_counts) if conv_counts else 0


def _load_dag_list(args):
    """Load the architecture codes for ``args.arch`` and choose which ones to train.

    Returns ``(all_dags_str, dag_list, depths_base)``:

    * ``all_dags_str`` -- codes read from the family's JSON file (for mlp/cnn,
      unseen ``--dags`` are appended); ``None`` for plain architectures.
    * ``dag_list`` -- indices into ``all_dags_str`` to train, honouring
      ``--dags``, ``--start_idx``/``--end_idx`` and ``--reverse_order``; for plain
      architectures it is ``[args.arch]``.
    * ``depths_base`` -- reference path depths for ``--lr_autoscale``
      (tinynetwork family only, otherwise ``None``).
    """
    if args.arch in ['mlp', 'cnn']:
        depths_base = None
        # alternative files: 'all_dags_str.json' / 'random_dag_list.npy'
        with open('all_dags_str_N5.json') as json_file:
            all_dags_str = json.load(json_file)
        if args.dags:
            dag_list = []
            for _dag in args.dags:
                if _dag in all_dags_str:
                    dag_list.append(all_dags_str.index(_dag))
                else:
                    # unseen DAG: register it so it gets its own index
                    all_dags_str.append(_dag)
                    dag_list.append(len(all_dags_str) - 1)
            return all_dags_str, dag_list, depths_base
        dag_list = np.load("random_dag_list_N5.npy")
    elif args.arch.startswith('tinynetwork'):
        depths_base = [1]  # 1_01_001
        if args.arch == "tinynetworksize":
            codes_file, order_file = 'arch_str_nats_sss.json', "arch_indice_nats_sss.npy"
        else:
            # alternative order: "arch_indice_201_elites_cifar100.npy" (sorted by CIFAR-100 test acc)
            codes_file, order_file = 'arch_code_201.json', "arch_indice_201.npy"
        with open(codes_file) as json_file:
            all_dags_str = json.load(json_file)
        if args.dags:
            return all_dags_str, [all_dags_str.index(_dag) for _dag in args.dags], depths_base
        dag_list = np.load(order_file)  # random list of indices
    else:
        return None, [args.arch], None

    if 0 <= args.start_idx < args.end_idx:
        dag_list = dag_list[args.start_idx:args.end_idx]
    if args.reverse_order:
        dag_list = dag_list[::-1]
    return all_dags_str, dag_list, depths_base


def _load_target_lr_list(args):
    """Return the learning rates to sweep: ``[args.lr]`` if given, else a preset list.

    The preset list can be sliced with ``--start_idx_lr``/``--end_idx_lr``.
    """
    if args.lr:
        return [args.lr]
    # Other presets: target_lr_list_0.003.30.0.172726.npy,
    # target_lr_list_0.003.89.0.074.npy (mlp, cnn; first 30 for cnn K7),
    # target_lr_list_0.014.50.0.0385.npy (mlp for ImageNet).
    target_lr_list = np.load("target_lr_list_0.005.30.1.597368.npy")  # tinynetwork
    if 0 <= args.start_idx_lr < args.end_idx_lr:
        target_lr_list = target_lr_list[args.start_idx_lr:args.end_idx_lr]
    return target_lr_list


def _build_job_name(args, target_lr_list):
    """Encode the sweep configuration into a directory name under ``--save_dir``."""
    return "BULK-{dataset}-{arch}{kernel}{width}{bn}{mup}-LR{lr}{scale}{nesterov}-BS{batch_size}{prune}-Epoch{epoch}{exp_name}".format(
        dataset=args.dataset + (".%d"%args.img_size if args.img_size else "") + (".Aug" if args.aug else "") + (".Cut%d"%args.cutout if args.cutout > 0 else ""),
        arch=args.arch + (".GeLU" if args.act == 'gelu' else ""),
        kernel=".K%d"%args.kernel_size if args.arch == "cnn" else "",
        width=".W%d"%args.width if args.width else "",
        bn=".BN" if args.bn else "",
        mup=".MuP" if args.mup else "",
        lr="%f"%args.lr if args.lr else "%f.%d.%f"%(target_lr_list[0], len(target_lr_list), target_lr_list[-1]),
        scale=".AutoScale" if args.lr_autoscale else "",
        nesterov=".Nev" if args.nesterov else "",
        batch_size=args.batch_size,
        epoch=str(args.epochs) + (".Steps%d"%args.steps if args.steps else ""),
        prune="-prune%.2f"%args.rand_prune if args.rand_prune > 0 else "",  # random pruning for now
        exp_name="" if args.exp_name == "" else "-"+args.exp_name,
    )


def main():
    """Parse the CLI and train every (DAG, learning rate, seed) combination.

    Sets the module-global ``args`` (read by ``train_model``/``train``). Each
    run gets its own ``args.save_dir`` under the job directory plus an ``.ing``
    lock file: runs whose lock file was touched within ``ING_GAP_SEC`` seconds
    are skipped as in progress elsewhere, runs with a result directory but no
    lock file are skipped as finished, and runs with a stale lock file are
    restarted (labelled "Resume").
    """
    global args
    args = parser.parse_args()

    if args.widths:
        args.width = args.widths  # overwrite width as widths for sss in NATS

    if args.seed is None:
        args.seed = random.randint(0, 999)

    if args.lr_autoscale and args.arch != "tinynetwork":
        # The autoscale rule needs per-path depths, which are only derived for
        # tinynetwork codes; other archs would crash inside the sweep loop.
        parser.error("--lr_autoscale is only supported with --arch tinynetwork")

    torch.cuda.set_device(int(args.gpu))

    if args.dataset in ['cifar10', 'cifar100']:
        # make sure both CIFAR variants are present on disk before training
        from torchvision.datasets import CIFAR10, CIFAR100
        CIFAR10(args.data, train=True, download=True)
        CIFAR10(args.data, train=False, download=True)
        CIFAR100(args.data, train=True, download=True)
        CIFAR100(args.data, train=False, download=True)

    all_dags_str, random_dag_list, depths_base = _load_dag_list(args)
    target_lr_list = _load_target_lr_list(args)

    if args.steps:
        args.epochs = 1  # --steps caps training at part of a single epoch

    SAVE_DIR = os.path.join(args.save_dir, _build_job_name(args, target_lr_list))

    PID = os.getpid()
    print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, SAVE_DIR))

    is_dag_arch = args.arch in ['mlp', 'cnn', 'tinynetwork', 'tinynetworksize']
    pbar = tqdm(random_dag_list, position=0, leave=True)
    for dag_idx in pbar:
        if is_dag_arch:
            if args.arch == "tinynetworksize":
                # "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|"
                args.dag = "3_33_133"
                args.width = [int(v) for v in all_dags_str[dag_idx].split(':')]
            else:
                args.dag = all_dags_str[dag_idx]
            # dag_idx is an integer index only for DAG archs; plain archs use
            # the arch name (a string) and have no per-DAG directory.
            dag_dir = os.path.join(SAVE_DIR, "%d"%dag_idx)
            os.makedirs(dag_dir, exist_ok=True)
            with open(os.path.join(dag_dir, "dag_str.txt"), 'w') as f:
                f.write("%s"%args.dag)
        for lr in target_lr_list:
            if args.lr_autoscale:
                # only reachable for tinynetwork (checked above)
                _aff = dag2affinity([[int(value) for value in node] for node in all_dags_str[dag_idx].split("_")])
                _, _width, _paths, _depth = effective_depth_width(_aff)  # max depth
                _depths = [1 + sum([_p > 0 for _p in _path]) for _path in _paths]  # compensate for skip connection
                lr = auto_scale_lr_depth_width(lr, depths_base, _depths)
            for r_idx in range(args.repeat):
                seed = args.seed + r_idx

                if is_dag_arch:
                    ing_file_name = os.path.join(SAVE_DIR, "dag%d.lr%f.seed%d.ing"%(dag_idx, lr, seed))
                    args.save_dir = os.path.join(SAVE_DIR, "%d"%dag_idx, "%f"%lr, "%d"%seed)
                else:
                    ing_file_name = os.path.join(SAVE_DIR, "%s.lr%f.seed%d.ing"%(dag_idx, lr, seed))
                    args.save_dir = os.path.join(SAVE_DIR, "%f"%lr, "%d"%seed)

                ing_exists = os.path.isfile(ing_file_name)
                # a recently touched lock file means another job is on this run;
                # no lock file but an existing run dir means it already finished
                running_elsewhere = ing_exists and abs(time.time() - check_modified_time(ing_file_name)) < ING_GAP_SEC
                finished = (not ing_exists) and os.path.exists(args.save_dir)
                if running_elsewhere or finished:
                    pbar.set_description("Skip %s LR %f seed %d"%(str(dag_idx), lr, seed))
                    continue

                # a stale lock file marks an interrupted run
                prefix = "Resume" if ing_exists else "Train"

                prepare_seed(seed)
                pbar.set_description("%s %s LR=%f seed %d"%(prefix, dag_idx, lr, seed))
                train_model(dag_idx, lr, ing_file_name)
                # release the lock file without spawning a shell (safe for any path)
                try:
                    os.remove(ing_file_name)
                except FileNotFoundError:
                    pass


def _touch(path):
    """Create ``path`` if needed and set its mtime to now, like ``touch(1)``.

    Serves as the heartbeat of the ``.ing`` lock file whose age ``main`` checks.
    No shell is spawned, so paths with spaces or shell metacharacters are safe.
    """
    with open(path, 'a'):
        pass
    os.utime(path, None)


def _build_dataloaders():
    """Build the data loaders for ``args.dataset``.

    Returns ``(classes, dummy_shape, train_loader, val_loader, test_loader)``,
    where ``dummy_shape`` is the per-sample input shape handed to ``build_model``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """
    flatten = args.arch == "mlp"  # the MLP consumes flattened images
    if args.dataset == 'cifar10':
        classes = 10
        if args.img_size:
            dummy_shape = (3, args.img_size, args.img_size)
        else:
            dummy_shape = (3, 32, 32)
        loaders = cifar10_dataloaders(
            batch_size=args.batch_size, data_dir=args.data, num_workers=args.workers,
            aug=args.aug, cutout=args.cutout,
            flatten=flatten, resize=args.img_size,
            crossval=False  # TODO
        )
    elif args.dataset == 'cifar100':
        classes = 100
        dummy_shape = (3, 32, 32)
        loaders = cifar100_dataloaders(
            batch_size=args.batch_size, data_dir=args.data, num_workers=args.workers,
            aug=args.aug, cutout=args.cutout,
            flatten=flatten, resize=args.img_size,
            crossval=False  # TODO
        )
    elif args.dataset == 'imagenet16_120':
        classes = 120
        dummy_shape = (3, 16, 16)
        loaders = imagenet_16_120_dataloaders(
            batch_size=args.batch_size, data_dir=args.data, num_workers=args.workers,
            aug=args.aug, cutout=args.cutout,
            flatten=flatten, resize=args.img_size
        )
    elif args.dataset == 'svhn':
        classes = 10
        dummy_shape = (3, 32, 32)
        loaders = svhn_dataloaders(batch_size=args.batch_size, data_dir=args.data, num_workers=args.workers, flatten=flatten)
    elif args.dataset == 'mnist':
        classes = 10
        dummy_shape = (1, 28, 28)
        loaders = mnist_dataloaders(batch_size=args.batch_size, data_dir=args.data, num_workers=args.workers, flatten=flatten)
    elif args.dataset == 'tinyimagenet':
        classes = 200
        dummy_shape = (3, 64, 64)
        loaders = imagenet_dataloaders(batch_size=args.batch_size, data_dir=args.data, num_workers=args.workers, flatten=flatten)
    elif args.dataset == 'imagenet':
        classes = 1000
        dummy_shape = (3, 224, 224)
        loaders = imagenet_dataloaders(batch_size=args.batch_size, img_shape=dummy_shape[1], data_dir=args.data, num_workers=args.workers, flatten=flatten)
    else:
        raise ValueError('Dataset not supported yet: {}'.format(args.dataset))
    train_loader, val_loader, test_loader = loaders
    return classes, dummy_shape, train_loader, val_loader, test_loader


def _build_optimizer(model, lr):
    """Create the optimizer for one run, plus an optional cosine LR scheduler.

    Returns ``(optimizer, scheduler)``; ``scheduler`` is ``None`` unless
    ``--decreasing_lr`` is set.
    """
    if args.mup:
        # Use the run's LR (it was hard-coded to 0.1, ignoring the LR sweep).
        # MuSGD forwards keyword arguments to torch.optim.SGD; its second
        # positional parameter is not the LR, so everything is passed by keyword.
        optimizer = MuSGD(model.parameters(), lr=lr, momentum=args.momentum,
                          weight_decay=args.weight_decay, nesterov=args.nesterov)
    elif args.adam:
        optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
    scheduler = None
    if args.decreasing_lr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)
    return optimizer, scheduler


def train_model(dag_idx, lr, ing_file_name):
    """Train, resume, or (with ``--inference``) only evaluate one model at learning rate ``lr``.

    Reads the module-global ``args``; ``main`` sets ``args.dag``/``args.width``/
    ``args.save_dir`` per run. Each epoch saves a checkpoint with
    ``save_checkpoint`` into ``args.save_dir``, logs scalars through
    ``logger.writer`` and redraws ``net_train.png``; at the end a results summary
    (without model/optimizer state) is saved the same way. ``ing_file_name`` is
    touched at start-up and every epoch as a liveness heartbeat. ``dag_idx`` is
    not used here.

    Returns 0 in ``--inference`` mode, otherwise ``None``.
    """
    global args
    logger = prepare_logger(args, verbose=False)
    _touch(ing_file_name)

    if not args.inference:
        os.makedirs(args.save_dir, exist_ok=True)

    classes, dummy_shape, train_loader, val_loader, test_loader = _build_dataloaders()

    model = build_model(args, classes, dummy_shape)
    logger.log(str(model))

    if args.mup:
        base_model = build_model(args, classes, dummy_shape, width=1)
        delta_model = build_model(args, classes, dummy_shape, width=2)
        set_base_shapes(model, base_model, delta=delta_model)
        for param in model.parameters():
            # under muP, (re-)initialise through mup.init instead of torch.nn.init
            mup.init.uniform_(param, -0.1, 0.1)

    # setup initialization and mask
    if args.rand_prune > 0:
        pruning_model_random_layer_specified(model, min(1, args.rand_prune))
        remain_weight_rate = check_sparsity(model)
        logger.log("remaining weight rate: %.2f"%remain_weight_rate)

    model = model.cuda()

    criterion = nn.CrossEntropyLoss()
    optimizer, scheduler = _build_optimizer(model, lr)

    if args.inference:
        if args.checkpoint:
            checkpoint = torch.load(args.checkpoint, map_location=torch.device('cuda:'+str(args.gpu)))
            if 'state_dict' in checkpoint.keys():
                checkpoint = checkpoint['state_dict']
            model.load_state_dict(checkpoint)

        # validate() returns (loss, accuracy); log the accuracy, not the tuple
        _, test_acc = validate(test_loader, model, criterion, 0)
        logger.log('* Test Accuracy = {}'.format(test_acc))
        return 0

    ckpt_path = "%s/checkpoint.pth.tar"%args.save_dir
    if os.path.isfile(ckpt_path):
        logger.log('resume from checkpoint {}'.format(ckpt_path))
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cuda:'+str(args.gpu)))
        if 'state_dict' not in checkpoint:
            # Presumably the end-of-run summary saved at the bottom of this
            # function (it carries no model/optimizer state): nothing to resume.
            logger.log('checkpoint has no state_dict (final summary); nothing to resume')
            return
        start_epoch = checkpoint['epoch']
        all_result = checkpoint['result']
        # keep the per-step loss history from before the restart
        loss_steps_epochs = list(checkpoint.get('train_losses', []))

        # Reuse the optimizer/scheduler built above so Adam and MuSGD runs are
        # restored into an optimizer of the same type and param-group layout.
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if scheduler is not None and checkpoint.get('scheduler') is not None:
            scheduler.load_state_dict(checkpoint['scheduler'])
        logger.log('loading from epoch: {}'.format(start_epoch))
    else:
        all_result = {key: [] for key in ('train_acc', 'test_acc', 'val_acc', 'train_loss', 'test_loss', 'val_loss')}
        loss_steps_epochs = []
        start_epoch = 0

    logger.log("Path {}".format(args.save_dir))
    end_epoch = start_epoch  # number of epochs completed; valid even if the loop runs zero times
    for epoch in range(start_epoch, args.epochs):
        _touch(ing_file_name)  # heartbeat for main()'s staleness check

        train_loss, train_acc, loss_steps = train(train_loader, model, criterion, optimizer, epoch, args.steps)
        loss_steps_epochs += list(loss_steps)
        logger.writer.add_scalar("train/loss", train_loss, epoch)
        logger.writer.add_scalar("train/accuracy", train_acc, epoch)
        pbar_str = "Epoch:{} Train:{:.2f} (Loss:{:.4f}) ".format(epoch, train_acc, train_loss)
        all_result['train_acc'].append(train_acc)
        all_result['train_loss'].append(train_loss)
        if val_loader:
            val_loss, val_acc = validate(val_loader, model, criterion, epoch, split="Val")
            logger.writer.add_scalar("validation/loss", val_loss, epoch)
            logger.writer.add_scalar("validation/accuracy", val_acc, epoch)
            pbar_str += "Validation:{:.2f} (Loss:{:.4f})".format(val_acc, val_loss)
            all_result['val_acc'].append(val_acc)
            all_result['val_loss'].append(val_loss)
        if test_loader:
            test_loss, test_acc = validate(test_loader, model, criterion, epoch, split="Test")
            logger.writer.add_scalar("test/loss", test_loss, epoch)
            logger.writer.add_scalar("test/accuracy", test_acc, epoch)
            pbar_str += "Test:{:.2f} (Loss:{:.4f})".format(test_acc, test_loss)
            all_result['test_acc'].append(test_acc)
            all_result['test_loss'].append(test_loss)
        logger.log("Path {}".format(args.save_dir))
        logger.log(pbar_str)
        if scheduler:
            logger.log("LR:{}".format(scheduler.get_last_lr()[0]))
            scheduler.step()

        end_epoch = epoch + 1
        checkpoint = {
            'result': all_result,
            'epoch': end_epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict() if scheduler else None,
            'train_losses': loss_steps_epochs
        }
        save_checkpoint(checkpoint, is_best=False, save_path=args.save_dir)

        # plot curves: val/test first, then train (keeps the same colour order)
        if val_loader:
            plt.plot(all_result['val_acc'], label='val_acc')
        if test_loader:
            plt.plot(all_result['test_acc'], label='test_acc')
        plt.plot(all_result['train_acc'], label='train_acc')
        plt.legend()
        plt.savefig(os.path.join(args.save_dir, 'net_train.png'))
        plt.close()

        if np.isnan(train_loss):
            # diverged: pad the train curves to full length and stop
            remaining = args.epochs - epoch - 1
            all_result['train_acc'] += [train_acc] * remaining
            all_result['train_loss'] += [train_loss] * remaining
            break

    # report result
    if val_loader:
        val_pick_best_epoch = np.argmax(np.array(all_result['val_acc']))
    else:
        # Last epoch actually evaluated. Not len(train_acc) - 1: after a NaN
        # break the train lists are padded but val/test lists are not.
        val_pick_best_epoch = end_epoch - 1
    if test_loader:
        best_acc = all_result['test_acc'][val_pick_best_epoch]
        best_loss = all_result['test_loss'][val_pick_best_epoch]
    else:
        best_acc = all_result['val_acc'][val_pick_best_epoch]
        best_loss = all_result['val_loss'][val_pick_best_epoch]
    logger.log('* best accuracy = {}, best loss = {}, Epoch = {}'.format(best_acc, best_loss, val_pick_best_epoch+1))
    checkpoint = {
        'result': all_result,
        'epoch': end_epoch,
        'best_acc': best_acc,
        'best_loss': best_loss,
        'best_epoch': val_pick_best_epoch,
        'train_losses': loss_steps_epochs
    }
    save_checkpoint(checkpoint, is_best=False, save_path=args.save_dir)


def train(train_loader, model, criterion, optimizer, epoch, steps):
    """Train ``model`` for one epoch and return ``(mean loss, top-1 accuracy, per-step losses)``.

    When ``steps`` is an int, epoch 0 stops after ``steps`` batches. During the
    first ``args.warmup`` epochs the LR is adjusted every step via ``warmup_lr``.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    loss_steps = []
    # Warm up towards the LR this optimizer was built with: args.lr is None
    # during an LR sweep and differs from the run's LR under --lr_autoscale.
    # (assumes warmup_lr's second argument is the target/base LR -- confirm)
    base_lr = optimizer.defaults.get('lr', args.lr)

    # switch to train mode
    model.train()

    for i, (image, target) in enumerate(train_loader):
        if isinstance(steps, int) and epoch == 0 and i >= steps:
            break
        if epoch < args.warmup:
            warmup_lr(args.warmup, base_lr, epoch, i+1, optimizer, one_epoch_step=len(train_loader))

        image = image.cuda()
        target = target.cuda()

        # compute output
        output_clean = model(image)
        loss = criterion(output_clean, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        output = output_clean.float()
        loss_value = loss.float().item()  # a single device->host sync per step
        prec1, gt_num = accuracy(output.data, target, topk=(1,))
        top1.update(prec1[0], gt_num[0])
        loss_steps.append(float(loss_value))

        losses.update(loss_value, image.size(0))

    return float(losses.avg), float(top1.vec2sca_avg), loss_steps


def validate(val_loader, model, criterion, epoch, split="Test"):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    ``epoch`` and ``split`` are accepted for call-site symmetry but not used.
    Returns ``(mean loss, top-1 accuracy)`` as Python floats.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    model.eval()  # evaluation mode (e.g. BN uses running statistics)

    for image, target in val_loader:
        image, target = image.cuda(), target.cuda()

        with torch.no_grad():
            logits = model(image)
            batch_loss = criterion(logits, target)

        prec1, gt_num = accuracy(logits.float().data, target, topk=(1,))
        acc_meter.update(prec1[0], gt_num[0])
        loss_meter.update(batch_loss.float().item(), image.size(0))

    return float(loss_meter.avg), float(acc_meter.vec2sca_avg)


if __name__ == '__main__':
    # Script entry point: run the bulk (DAG x LR x seed) training sweep.
    main()
